"""
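Samplers for drawing from score-based diffusion models.

Provides an abstract ``Sampler`` interface, a predictor-corrector
implementation, and a small factory (``get_sampler``) to build one by name.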
"""

# pylint: disable=no-name-in-module

from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm

from diffusion_bandit.diffusion import DiffusionProcess

# pylint:disable=missing-class-docstring
# pylint:disable=missing-function-docstring


def get_sampler(name, **kwargs):
    """Build the sampler registered under ``name``.

    Args:
        name: Identifier of the sampler; only ``"pc"`` is recognised.
        **kwargs: Forwarded unchanged to the sampler constructor.

    Returns:
        A ``PredictorCorrectorSampler`` when ``name`` is ``"pc"``.

    Raises:
        NotImplementedError: For any other ``name``.
    """
    # Reject unknown names first so the happy path reads straight through.
    if name != "pc":
        raise NotImplementedError(f"Sampler '{name}' is not implemented.")
    return PredictorCorrectorSampler(**kwargs)


class Sampler(ABC):
    """Abstract interface shared by all samplers in this module."""

    def __init__(
        self,
        shape: Tuple[int, ...],
        diffusion_process: DiffusionProcess,
        device: str = "cpu",
        **kwargs,
    ):
        """Record the configuration common to every sampler.

        Args:
            shape: Stored as ``self.shape``.
            diffusion_process: Stored as ``self.diffusion_process``.
            device: Stored as ``self.device``.
            **kwargs: Accepted and ignored.
        """
        self.diffusion_process = diffusion_process
        self.device = device
        self.shape = shape

    @abstractmethod
    def sample(
        self,
        score_model: torch.nn.Module,
        batch_size: int = 64,
        num_steps: int = 1000,
        eps: float = 1e-5,
    ) -> List[Tensor]:
        """Draw samples using ``score_model``; concrete samplers override this.

        Args:
            score_model: Model used by the concrete sampler.
            batch_size: Number of samples requested.
            num_steps: Number of sampling steps.
            eps: Small positive time value used by the concrete sampler.

        Returns:
            A list of tensors produced by the concrete sampler.
        """


class PredictorCorrectorSampler(Sampler):
    """Predictor-corrector sampler.

    Every iteration evaluates the score model once, applies a Langevin
    corrector update and then a predictor update, while time runs from
    1.0 down to ``eps``.
    """

    def __init__(
        self,
        shape: Tuple[int, ...],
        diffusion_process: DiffusionProcess,
        snr: float = 0.16,
        device: str = "cuda",
        alpha_max: float = 30.0,
        **kwargs,
    ):
        """Store the sampler configuration.

        Args:
            shape: Per-sample shape (the batch dimension is prepended when
                sampling).
            diffusion_process: Object providing ``marginal_prob_std`` and
                ``beta``.
            snr: Ratio used to size the Langevin corrector step.
            device: Device on which the samples are created.
            alpha_max: Scale of the value passed to ``score_model.set_alpha``
                for models whose ``type`` is ``"combined"``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(shape, diffusion_process, device)
        self.snr = snr
        self.alpha_max = alpha_max

    def sample(
        self,
        score_model: torch.nn.Module,
        batch_size: int = 64,
        num_steps: int = 500,
        eps: float = 1e-5,
        return_full: bool = False,
    ) -> List[Tensor]:
        """Run the predictor-corrector loop and return the samples.

        Args:
            score_model: Callable as ``score_model(batch_time, x)`` and
                returning a tensor shaped like ``x``. If its ``type``
                attribute equals ``"combined"``, ``set_alpha`` is called on
                it before every step.
            batch_size: Number of samples drawn in parallel.
            num_steps: Number of points on the time grid
                ``np.linspace(1.0, eps, num_steps)``; must be at least 2.
            eps: Final (smallest) time of the grid.
            return_full: When true, return the whole trajectory instead of
                only the final state.

        Returns:
            ``[final_x]`` by default. With ``return_full`` the list holds the
            initial state followed by the state after each of the
            ``num_steps`` iterations.

        Raises:
            ValueError: If ``num_steps`` is smaller than 2, since the step
                size is the gap between the first two grid points.
            AssertionError: If the score model output does not have the same
                shape as its input.
        """
        if num_steps < 2:
            raise ValueError(
                f"num_steps must be at least 2 to define a step size, got {num_steps}."
            )

        # Initial state: standard normal noise scaled by the marginal std at t=1.
        time = torch.ones(batch_size, device=self.device)
        init_noise = torch.randn(batch_size, *self.shape, device=self.device)
        init_std = self.diffusion_process.marginal_prob_std(time).view(
            -1, *([1] * len(self.shape))
        )
        iter_x = init_noise * init_std

        time_steps = np.linspace(1.0, eps, num_steps)
        step_size = time_steps[0] - time_steps[1]

        # Always bound so the loop and the return never touch an undefined name.
        paths: List[Tensor] = [iter_x] if return_full else []

        with torch.no_grad():
            for time_step in tqdm(time_steps):
                if score_model.type == "combined":
                    self._update_alpha(score_model, time_step)

                batch_time_step = (
                    torch.ones((batch_size, 1), device=self.device) * time_step
                )

                # One score evaluation per iteration; it is shared by the
                # corrector and the predictor updates below.
                grad = score_model(batch_time_step, iter_x)
                assert (
                    grad.shape == iter_x.shape
                ), f"Gradient shape {grad.shape} doesn't match x shape {iter_x.shape}"

                iter_x = self._corrector_step(iter_x, grad)
                iter_x = self._predictor_step(
                    iter_x, grad, batch_time_step, step_size
                )
                if return_full:
                    paths.append(iter_x)

        return paths if return_full else [iter_x]

    def _update_alpha(self, score_model: torch.nn.Module, time_step: float) -> None:
        """Set the model's alpha to ``alpha_max * (1 - marginal_prob_std(t))``."""
        # NOTE(review): this tensor is created on the default device, unlike the
        # other time tensors which use self.device — confirm that is intended.
        time_tensor = torch.tensor(time_step).unsqueeze(-1)
        std_gap = 1 - self.diffusion_process.marginal_prob_std(time=time_tensor)
        score_model.set_alpha(self.alpha_max * std_gap.item())

    def _corrector_step(self, iter_x: Tensor, grad: Tensor) -> Tensor:
        """Apply one Langevin update with an ``snr``-based step size."""
        # Mean per-sample gradient norm across the batch.
        grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
        # sqrt of the number of elements in one sample.
        noise_norm = np.sqrt(np.prod(iter_x.shape[1:]))
        langevin_step_size = 2 * (self.snr * noise_norm / grad_norm) ** 2
        return (
            iter_x
            + langevin_step_size * grad
            + torch.sqrt(2 * langevin_step_size) * torch.randn_like(iter_x)
        )

    def _predictor_step(
        self,
        iter_x: Tensor,
        grad: Tensor,
        batch_time_step: Tensor,
        step_size: float,
    ) -> Tensor:
        """Apply one predictor update using ``beta(t) * step_size``.

        ``grad`` is the score evaluated before the corrector update; it is
        reused here rather than recomputed at the corrected state.
        """
        beta_t = self.diffusion_process.beta(batch_time_step) * step_size
        drift = 1 / torch.sqrt(1 - beta_t) * (iter_x + 0.5 * beta_t * grad)
        return drift + torch.sqrt(beta_t) * torch.randn_like(iter_x)
